{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "3bca1827-a930-4a72-b819-394544442e52",
   "metadata": {},
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "import os\n",
    "import openai\n",
    "import json\n",
    "import groq\n",
    "from sklearn.metrics import matthews_corrcoef, accuracy_score, f1_score\n",
    "#from datasets import load_dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "04c00ff8-459a-4789-9e93-637d1f02820c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Gold `entailment` labels plus the stored prediction column of each evaluated model.\n",
    "docs = pd.read_csv(filepath_or_buffer='../data/polnli_test_results.csv')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "3d3e3066-594b-4101-b45a-70576e5d0195",
   "metadata": {},
   "outputs": [],
   "source": [
    "def metrics(df, preds, group_by=None):\n",
    "    \"\"\"Score prediction columns against the gold `entailment` labels.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    df : pd.DataFrame\n",
    "        Frame holding the gold `entailment` column and every column in `preds`.\n",
    "    preds : iterable of str\n",
    "        Names of the prediction columns to score.\n",
    "    group_by : str or None, default None\n",
    "        Column of `df` to score within each group (e.g. 'dataset' or 'task').\n",
    "        None scores every prediction column over all rows.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    pd.DataFrame\n",
    "        MCC, Accuracy and weighted F1 per prediction column, indexed by\n",
    "        'Column' (plus the capitalized `group_by` name when grouping).\n",
    "\n",
    "    Raises\n",
    "    ------\n",
    "    KeyError\n",
    "        If `group_by` is given but is not a column of `df` (such values used\n",
    "        to be silently ignored, returning ungrouped scores).\n",
    "    \"\"\"\n",
    "    true_col = 'entailment'\n",
    "\n",
    "    def get_metrics(y_true, y_pred):\n",
    "        return {\n",
    "            'MCC': matthews_corrcoef(y_true, y_pred),\n",
    "            'Accuracy': accuracy_score(y_true, y_pred),\n",
    "            'F1': f1_score(y_true, y_pred, average='weighted'),\n",
    "        }\n",
    "\n",
    "    if group_by is not None and group_by not in df.columns:\n",
    "        raise KeyError(f\"group_by column {group_by!r} not found in df\")\n",
    "\n",
    "    group_label = None if group_by is None else str(group_by).capitalize()\n",
    "    index_cols = ['Column'] if group_label is None else ['Column', group_label]\n",
    "\n",
    "    results = []\n",
    "    for col in preds:\n",
    "        if group_by is None:\n",
    "            # `scores`, not `metrics`, so the local doesn't shadow this function.\n",
    "            scores = get_metrics(df[true_col], df[col])\n",
    "            scores['Column'] = col\n",
    "            results.append(scores)\n",
    "            continue\n",
    "        for group_name, group in df.groupby(group_by):\n",
    "            scores = get_metrics(group[true_col], group[col])\n",
    "            scores['Column'] = col\n",
    "            scores[group_label] = group_name\n",
    "            results.append(scores)\n",
    "\n",
    "    # Explicit columns keep set_index working even when `preds` is empty.\n",
    "    results_df = pd.DataFrame(results, columns=['MCC', 'Accuracy', 'F1', *index_cols])\n",
    "    if group_label is None:\n",
    "        return results_df.set_index('Column')\n",
    "    return results_df.set_index(index_cols)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d40ba604-a40b-485e-908d-9467668b468d",
   "metadata": {},
   "source": [
    "# Classify test pairs with Llama 3.3 70B via the Groq API"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 66,
   "id": "497e5c34-f1f1-43e5-88bd-77696b0fc52c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prefer the GROQ_API_KEY environment variable; fall back to the local key\n",
    "# file so the secret never has to be pasted into the notebook itself.\n",
    "api_key = os.environ.get('GROQ_API_KEY', '').strip()\n",
    "if not api_key:\n",
    "    with open('../groq_key.txt', 'r') as file:\n",
    "        api_key = file.read().strip()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 63,
   "id": "822335d4-ced6-4b18-998c-1be13b6938d7",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prompt template, filled per row via str.format(doc=..., hypothesis=...).\n",
    "# Label convention: 0 = the hypothesis holds for the text, 1 = it does not.\n",
    "user_message = \"\"\"You are a classifier that can only respond with 0 or 1. I'm going to show you a short text sample and I want you to determine if {hypothesis}. Here is the text:\n",
    "{doc}\n",
    "\n",
    "If it is true that {hypothesis}, return 0. If it is not true that {hypothesis}, return 1.\n",
    "Do not explain your answer, and only return 0 or 1.\n",
    "\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 72,
   "id": "a08ada3b-a6ba-4a96-be2f-466d685975fa",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 1min 3s, sys: 11.8 s, total: 1min 14s\n",
      "Wall time: 1h 1min 41s\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "<timed exec>:28: FutureWarning: Downcasting behavior in `replace` is deprecated and will be removed in a future version. To retain the old behavior, explicitly call `result.infer_objects(copy=False)`. To opt-in to the future behavior, set `pd.set_option('future.no_silent_downcasting', True)`\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "# Query Llama 3.3 70B on Groq once per test pair. Raw replies are kept in\n",
    "# `labels` so they can be re-parsed without repeating the long API pass.\n",
    "model = \"llama-3.3-70b-versatile\"\n",
    "\n",
    "# The SDK retries rate-limit/transient errors with backoff; raise the default\n",
    "# so a single hiccup late in a long run doesn't discard every result so far.\n",
    "client = groq.Groq(api_key=api_key, max_retries=5)\n",
    "\n",
    "labels = []\n",
    "for doc, hypothesis in zip(docs['premise'], docs['augmented_hypothesis']):\n",
    "    res = client.chat.completions.create(\n",
    "        messages=[\n",
    "            {\n",
    "                \"role\": \"user\",\n",
    "                \"content\": user_message.format(doc=doc, hypothesis=hypothesis),\n",
    "            }\n",
    "        ],\n",
    "        # The language model which will generate the completion.\n",
    "        model=model,\n",
    "        temperature=0,\n",
    "        max_completion_tokens=2,\n",
    "    )\n",
    "    labels.append(res.choices[0].message.content)\n",
    "\n",
    "# Parse replies explicitly: '0' -> 0, '1' -> 1 (after stripping padding), and\n",
    "# anything else ('2', refusals, None) -> -1. Unlike replace(), this never\n",
    "# leaves stray strings in the int column and avoids the downcasting warning.\n",
    "docs['llama70b'] = (\n",
    "    pd.Series(labels, index=docs.index)\n",
    "    .astype(str)\n",
    "    .str.strip()\n",
    "    .map({'0': 0, '1': 1})\n",
    "    .fillna(-1)\n",
    "    .astype(int)\n",
    ")\n",
    "docs.to_csv('polnli_docs_results.csv', index = False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5bfda9c5-53e7-4246-bf09-cf18617342dc",
   "metadata": {},
   "source": [
    "# Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 76,
   "id": "680ee7a4-04b0-47bd-a97d-46277f9a5c18",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/var/folders/1q/xggcl4hn6mx_q8dfxbx_whpr0000gq/T/ipykernel_24978/10051362.py:2: FutureWarning: Downcasting behavior in `replace` is deprecated and will be removed in a future version. To retain the old behavior, explicitly call `result.infer_objects(copy=False)`. To opt-in to the future behavior, set `pd.set_option('future.no_silent_downcasting', True)`\n",
      "  docs['llama70b'] = docs['llama70b'].replace({'1':1, '0':0, '2':-1})\n"
     ]
    }
   ],
   "source": [
    "# Rebuild the integer llama70b column from the raw replies and save it next\n",
    "# to the other models' predictions: '0' -> 0, '1' -> 1, anything else -> -1.\n",
    "# Explicit mapping avoids stray strings and pandas' replace() downcast warning.\n",
    "docs['llama70b'] = (\n",
    "    pd.Series(labels, index=docs.index)\n",
    "    .astype(str)\n",
    "    .str.strip()\n",
    "    .map({'0': 0, '1': 1})\n",
    "    .fillna(-1)\n",
    "    .astype(int)\n",
    ")\n",
    "docs.to_csv('../data/polnli_test_results.csv', index = False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "id": "486b5655-646a-4e81-8207-04af3e97fca9",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>premise</th>\n",
       "      <th>hypothesis</th>\n",
       "      <th>entailment</th>\n",
       "      <th>dataset</th>\n",
       "      <th>task</th>\n",
       "      <th>augmented_hypothesis</th>\n",
       "      <th>base_nli</th>\n",
       "      <th>large_nli</th>\n",
       "      <th>llama</th>\n",
       "      <th>sonnet</th>\n",
       "      <th>base_debate</th>\n",
       "      <th>large_debate</th>\n",
       "      <th>base_modern</th>\n",
       "      <th>large_modern</th>\n",
       "      <th>llama70b</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>0</th>\n",
       "      <td>Transport workers strike to protest rising fue...</td>\n",
       "      <td>The event described in this text is a strike.</td>\n",
       "      <td>0</td>\n",
       "      <td>mlburnham/scad_event_entailment</td>\n",
       "      <td>event extraction</td>\n",
       "      <td>the incident mentioned in this text is a strike.</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>1</th>\n",
       "      <td>Municipal workers strike over pay.</td>\n",
       "      <td>The event described in this text is a strike.</td>\n",
       "      <td>0</td>\n",
       "      <td>mlburnham/scad_event_entailment</td>\n",
       "      <td>event extraction</td>\n",
       "      <td>the occurrence detailed in this passage is a s...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>2</th>\n",
       "      <td>Niger's mining sector strikes.</td>\n",
       "      <td>The event described in this text is a strike.</td>\n",
       "      <td>0</td>\n",
       "      <td>mlburnham/scad_event_entailment</td>\n",
       "      <td>event extraction</td>\n",
       "      <td>the event described in this text is a strike.</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>3</th>\n",
       "      <td>Separatist movement protests detention of lead...</td>\n",
       "      <td>The event described in this text is a strike.</td>\n",
       "      <td>0</td>\n",
       "      <td>mlburnham/scad_event_entailment</td>\n",
       "      <td>event extraction</td>\n",
       "      <td>the event described in this text is a strike.</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>4</th>\n",
       "      <td>Janitors and hospital support staff staged a s...</td>\n",
       "      <td>The event described in this text is a strike.</td>\n",
       "      <td>0</td>\n",
       "      <td>mlburnham/scad_event_entailment</td>\n",
       "      <td>event extraction</td>\n",
       "      <td>the occurrence detailed in this passage is a s...</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>0</td>\n",
       "      <td>1</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                             premise  \\\n",
       "0  Transport workers strike to protest rising fue...   \n",
       "1                 Municipal workers strike over pay.   \n",
       "2                     Niger's mining sector strikes.   \n",
       "3  Separatist movement protests detention of lead...   \n",
       "4  Janitors and hospital support staff staged a s...   \n",
       "\n",
       "                                      hypothesis  entailment  \\\n",
       "0  The event described in this text is a strike.           0   \n",
       "1  The event described in this text is a strike.           0   \n",
       "2  The event described in this text is a strike.           0   \n",
       "3  The event described in this text is a strike.           0   \n",
       "4  The event described in this text is a strike.           0   \n",
       "\n",
       "                           dataset              task  \\\n",
       "0  mlburnham/scad_event_entailment  event extraction   \n",
       "1  mlburnham/scad_event_entailment  event extraction   \n",
       "2  mlburnham/scad_event_entailment  event extraction   \n",
       "3  mlburnham/scad_event_entailment  event extraction   \n",
       "4  mlburnham/scad_event_entailment  event extraction   \n",
       "\n",
       "                                augmented_hypothesis  base_nli  large_nli  \\\n",
       "0   the incident mentioned in this text is a strike.         0          0   \n",
       "1  the occurrence detailed in this passage is a s...         0          0   \n",
       "2      the event described in this text is a strike.         0          0   \n",
       "3      the event described in this text is a strike.         0          0   \n",
       "4  the occurrence detailed in this passage is a s...         0          0   \n",
       "\n",
       "   llama  sonnet  base_debate  large_debate  base_modern  large_modern  \\\n",
       "0      0       0            0             0            0             0   \n",
       "1      0       0            0             0            0             0   \n",
       "2      0       0            0             0            0             0   \n",
       "3      0       0            0             0            0             0   \n",
       "4      0       0            0             0            0             0   \n",
       "\n",
       "  llama70b  \n",
       "0        1  \n",
       "1        1  \n",
       "2        1  \n",
       "3        1  \n",
       "4        1  "
      ]
     },
     "execution_count": 60,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Spot-check the first rows of the combined labels/predictions table.\n",
    "docs.head(n=5)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "id": "aaf114c3-aa2c-4232-9332-0e34804de5c7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th>MCC</th>\n",
       "      <th>Accuracy</th>\n",
       "      <th>F1</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Column</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th>base_nli</th>\n",
       "      <td>0.657027</td>\n",
       "      <td>0.834375</td>\n",
       "      <td>0.830335</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>large_nli</th>\n",
       "      <td>0.718800</td>\n",
       "      <td>0.863074</td>\n",
       "      <td>0.859911</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>base_debate</th>\n",
       "      <td>0.892088</td>\n",
       "      <td>0.947872</td>\n",
       "      <td>0.947670</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>large_debate</th>\n",
       "      <td>0.915911</td>\n",
       "      <td>0.959326</td>\n",
       "      <td>0.959180</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>llama</th>\n",
       "      <td>0.730997</td>\n",
       "      <td>0.862358</td>\n",
       "      <td>0.863467</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>llama70b</th>\n",
       "      <td>0.804639</td>\n",
       "      <td>0.902968</td>\n",
       "      <td>0.903514</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>sonnet</th>\n",
       "      <td>0.815902</td>\n",
       "      <td>0.910517</td>\n",
       "      <td>0.909423</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                   MCC  Accuracy        F1\n",
       "Column                                    \n",
       "base_nli      0.657027  0.834375  0.830335\n",
       "large_nli     0.718800  0.863074  0.859911\n",
       "base_debate   0.892088  0.947872  0.947670\n",
       "large_debate  0.915911  0.959326  0.959180\n",
       "llama         0.730997  0.862358  0.863467\n",
       "llama70b      0.804639  0.902968  0.903514\n",
       "sonnet        0.815902  0.910517  0.909423"
      ]
     },
     "execution_count": 73,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Overall scores for every model on the full test set (no grouping).\n",
    "metrics(\n",
    "    docs,\n",
    "    ['base_nli', 'large_nli', 'base_debate', 'large_debate',\n",
    "     'llama', 'llama70b', 'sonnet'],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 74,
   "id": "a13493d1-8ae4-41d9-8be6-55f0de2759a7",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<div>\n",
       "<style scoped>\n",
       "    .dataframe tbody tr th:only-of-type {\n",
       "        vertical-align: middle;\n",
       "    }\n",
       "\n",
       "    .dataframe tbody tr th {\n",
       "        vertical-align: top;\n",
       "    }\n",
       "\n",
       "    .dataframe thead th {\n",
       "        text-align: right;\n",
       "    }\n",
       "</style>\n",
       "<table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       "    <tr style=\"text-align: right;\">\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th>MCC</th>\n",
       "      <th>Accuracy</th>\n",
       "      <th>F1</th>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>Column</th>\n",
       "      <th>Task</th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "      <th></th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">base_nli</th>\n",
       "      <th>event extraction</th>\n",
       "      <td>0.538591</td>\n",
       "      <td>0.753841</td>\n",
       "      <td>0.753169</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>hatespeech and toxicity</th>\n",
       "      <td>0.550569</td>\n",
       "      <td>0.858095</td>\n",
       "      <td>0.845361</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>stance detection</th>\n",
       "      <td>0.530711</td>\n",
       "      <td>0.775285</td>\n",
       "      <td>0.770202</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>topic classification</th>\n",
       "      <td>0.871001</td>\n",
       "      <td>0.935212</td>\n",
       "      <td>0.934550</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">large_nli</th>\n",
       "      <th>event extraction</th>\n",
       "      <td>0.723042</td>\n",
       "      <td>0.852304</td>\n",
       "      <td>0.852599</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>hatespeech and toxicity</th>\n",
       "      <td>0.553551</td>\n",
       "      <td>0.854430</td>\n",
       "      <td>0.848143</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>stance detection</th>\n",
       "      <td>0.585007</td>\n",
       "      <td>0.797717</td>\n",
       "      <td>0.789429</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>topic classification</th>\n",
       "      <td>0.896400</td>\n",
       "      <td>0.948081</td>\n",
       "      <td>0.947663</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">base_debate</th>\n",
       "      <th>event extraction</th>\n",
       "      <td>0.765923</td>\n",
       "      <td>0.878492</td>\n",
       "      <td>0.878934</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>hatespeech and toxicity</th>\n",
       "      <td>0.856644</td>\n",
       "      <td>0.950700</td>\n",
       "      <td>0.950405</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>stance detection</th>\n",
       "      <td>0.938404</td>\n",
       "      <td>0.970158</td>\n",
       "      <td>0.970135</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>topic classification</th>\n",
       "      <td>0.929391</td>\n",
       "      <td>0.965387</td>\n",
       "      <td>0.965351</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">large_debate</th>\n",
       "      <th>event extraction</th>\n",
       "      <td>0.819049</td>\n",
       "      <td>0.909218</td>\n",
       "      <td>0.909492</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>hatespeech and toxicity</th>\n",
       "      <td>0.882548</td>\n",
       "      <td>0.959694</td>\n",
       "      <td>0.959374</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>stance detection</th>\n",
       "      <td>0.969009</td>\n",
       "      <td>0.984979</td>\n",
       "      <td>0.984972</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>topic classification</th>\n",
       "      <td>0.924496</td>\n",
       "      <td>0.962503</td>\n",
       "      <td>0.962322</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">llama</th>\n",
       "      <th>event extraction</th>\n",
       "      <td>0.808244</td>\n",
       "      <td>0.905726</td>\n",
       "      <td>0.905480</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>hatespeech and toxicity</th>\n",
       "      <td>0.559060</td>\n",
       "      <td>0.782145</td>\n",
       "      <td>0.799067</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>stance detection</th>\n",
       "      <td>0.605609</td>\n",
       "      <td>0.798117</td>\n",
       "      <td>0.799651</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>topic classification</th>\n",
       "      <td>0.918734</td>\n",
       "      <td>0.959396</td>\n",
       "      <td>0.959505</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">llama70b</th>\n",
       "      <th>event extraction</th>\n",
       "      <td>0.844209</td>\n",
       "      <td>0.921788</td>\n",
       "      <td>0.922018</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>hatespeech and toxicity</th>\n",
       "      <td>0.600318</td>\n",
       "      <td>0.815456</td>\n",
       "      <td>0.828349</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>stance detection</th>\n",
       "      <td>0.762784</td>\n",
       "      <td>0.882235</td>\n",
       "      <td>0.882859</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>topic classification</th>\n",
       "      <td>0.944122</td>\n",
       "      <td>0.972265</td>\n",
       "      <td>0.972317</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th rowspan=\"4\" valign=\"top\">sonnet</th>\n",
       "      <th>event extraction</th>\n",
       "      <td>0.784838</td>\n",
       "      <td>0.880936</td>\n",
       "      <td>0.881063</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>hatespeech and toxicity</th>\n",
       "      <td>0.571282</td>\n",
       "      <td>0.862425</td>\n",
       "      <td>0.853626</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>stance detection</th>\n",
       "      <td>0.791883</td>\n",
       "      <td>0.899259</td>\n",
       "      <td>0.898415</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <th>topic classification</th>\n",
       "      <td>0.946775</td>\n",
       "      <td>0.973819</td>\n",
       "      <td>0.973765</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table>\n",
       "</div>"
      ],
      "text/plain": [
       "                                           MCC  Accuracy        F1\n",
       "Column       Task                                                 \n",
       "base_nli     event extraction         0.538591  0.753841  0.753169\n",
       "             hatespeech and toxicity  0.550569  0.858095  0.845361\n",
       "             stance detection         0.530711  0.775285  0.770202\n",
       "             topic classification     0.871001  0.935212  0.934550\n",
       "large_nli    event extraction         0.723042  0.852304  0.852599\n",
       "             hatespeech and toxicity  0.553551  0.854430  0.848143\n",
       "             stance detection         0.585007  0.797717  0.789429\n",
       "             topic classification     0.896400  0.948081  0.947663\n",
       "base_debate  event extraction         0.765923  0.878492  0.878934\n",
       "             hatespeech and toxicity  0.856644  0.950700  0.950405\n",
       "             stance detection         0.938404  0.970158  0.970135\n",
       "             topic classification     0.929391  0.965387  0.965351\n",
       "large_debate event extraction         0.819049  0.909218  0.909492\n",
       "             hatespeech and toxicity  0.882548  0.959694  0.959374\n",
       "             stance detection         0.969009  0.984979  0.984972\n",
       "             topic classification     0.924496  0.962503  0.962322\n",
       "llama        event extraction         0.808244  0.905726  0.905480\n",
       "             hatespeech and toxicity  0.559060  0.782145  0.799067\n",
       "             stance detection         0.605609  0.798117  0.799651\n",
       "             topic classification     0.918734  0.959396  0.959505\n",
       "llama70b     event extraction         0.844209  0.921788  0.922018\n",
       "             hatespeech and toxicity  0.600318  0.815456  0.828349\n",
       "             stance detection         0.762784  0.882235  0.882859\n",
       "             topic classification     0.944122  0.972265  0.972317\n",
       "sonnet       event extraction         0.784838  0.880936  0.881063\n",
       "             hatespeech and toxicity  0.571282  0.862425  0.853626\n",
       "             stance detection         0.791883  0.899259  0.898415\n",
       "             topic classification     0.946775  0.973819  0.973765"
      ]
     },
     "execution_count": 74,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Same models, scored separately within each task.\n",
    "metrics(\n",
    "    docs,\n",
    "    ['base_nli', 'large_nli', 'base_debate', 'large_debate',\n",
    "     'llama', 'llama70b', 'sonnet'],\n",
    "    group_by='task',\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
